from scene.dataset_readers import getNerfppNorm, BasicPointCloud, SH2RGB, storePly, fetchPly, SceneInfo, focal2fov, fov2focal, CameraInfo
from steps_params import args
import numpy as np
import math
import os
import json
from PIL import Image
from pathlib import Path

def readCamFromNerfTransforms(path, transformsfile, white_background, extension=".png"):
    """Build a list of ``CameraInfo`` from a NeRF-style transforms JSON file.

    The JSON must provide ``focal_len``, ``cx``, ``cy`` and a ``frames`` list
    whose entries carry ``img_id`` and a 4x4 ``transform_matrix``
    (camera-to-world, OpenGL/Blender axes). Each frame's image is read from
    ``<path>/gt_imgs/<img_id><extension>`` and alpha-composited onto a solid
    background.

    Args:
        path: Dataset root directory.
        transformsfile: Name of the transforms JSON, relative to ``path``.
        white_background: If True, composite transparent pixels onto white;
            otherwise onto black.
        extension: Image file extension, including the leading dot.

    Returns:
        A list of ``CameraInfo``, one per frame. ``uid`` is the frame's index
        in ``frames``.

    Raises:
        OSError: if the transforms file or an image cannot be opened.
        KeyError: if a required JSON field is missing.
    """
    # Load the JSON up front so the file is not held open while images are read.
    with open(os.path.join(path, transformsfile), encoding="utf-8") as json_file:
        contents = json.load(json_file)

    focal_length = contents["focal_len"]
    # FoV is derived from the principal point. This assumes cx/cy lie at the
    # image centre (cx = W/2, cy = H/2).
    fovx = 2 * math.atan(contents["cx"] / focal_length)
    fovy = 2 * math.atan(contents["cy"] / focal_length)

    # The background colour is identical for every frame, so compute it once.
    bg = np.array([1, 1, 1]) if white_background else np.array([0, 0, 0])

    cam_infos = []
    for idx, frame in enumerate(contents["frames"]):
        cam_name = os.path.join("gt_imgs", f"{frame['img_id']}{extension}")

        # NeRF 'transform_matrix' is a camera-to-world transform.
        c2w = np.array(frame["transform_matrix"])
        # Change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward).
        c2w[:3, 1:3] *= -1

        # Get the world-to-camera transform and set R, T.
        w2c = np.linalg.inv(c2w)
        R = np.transpose(w2c[:3, :3])  # R is stored transposed due to 'glm' in CUDA code
        T = w2c[:3, 3]

        image_path = os.path.join(path, cam_name)
        image_name = Path(cam_name).stem
        # Context manager closes the underlying file. Otherwise one handle
        # leaks per frame until garbage collection.
        with Image.open(image_path) as img:
            im_data = np.array(img.convert("RGBA"))

        # Alpha-composite RGB over the background in [0, 1] space.
        norm_data = im_data / 255.0
        alpha = norm_data[:, :, 3:4]
        arr = norm_data[:, :, :3] * alpha + bg * (1 - alpha)
        # Cast to *unsigned* 8-bit. The previous np.byte (signed int8) cannot
        # hold 128..255, and float->int8 overflow is platform-dependent.
        rgb = np.clip(arr * 255.0, 0, 255).astype(np.uint8)
        image = Image.fromarray(rgb)  # (H, W, 3) uint8 -> "RGB"

        cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=fovy, FovX=fovx, image=image,
                                    image_path=image_path, image_name=image_name,
                                    width=image.size[0], height=image.size[1]))

    return cam_infos

# Dataset location and loader options for this script.
path = "./assets/May"
white_background = False
extension = ".jpg"

# Only the validation transforms are loaded. The same camera list is passed
# as both train_cameras and test_cameras of the SceneInfo below.
test_cam_infos = readCamFromNerfTransforms(path, "transforms_val.json", white_background, extension)

nerf_normalization = getNerfppNorm(test_cam_infos)

ply_path = os.path.join(path, "points3d.ply")
if not os.path.exists(ply_path):
    # Since this data set has no colmap data, we start with random points
    num_pts = 100_000
    print(f"Generating random point cloud ({num_pts})...")

    # Uniform random points in the cube [-1.3, 1.3)^3. These bounds come from
    # the synthetic Blender scenes. NOTE(review): confirm they cover this
    # dataset's scene extent.
    xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
    shs = np.random.random((num_pts, 3)) / 255.0
    storePly(ply_path, xyz, SH2RGB(shs) * 255)

# Load the point cloud from disk, either pre-existing or just generated above.
# A missing or unreadable file is tolerated and yields pcd = None. The reason
# is reported instead of being silently discarded, and KeyboardInterrupt or
# SystemExit are no longer swallowed.
try:
    pcd = fetchPly(ply_path)
except Exception as err:
    print(f"Could not load point cloud from {ply_path}: {err}")
    pcd = None

scene_info = SceneInfo(point_cloud=pcd,
                        train_cameras=test_cam_infos,
                        test_cameras=test_cam_infos,
                        nerf_normalization=nerf_normalization,
                        ply_path=ply_path)